#include "prediction.hpp"

#include <cstddef>
#include <stdexcept>
#include <utility>

using namespace std;
SelectWithPrediction::SelectWithPrediction() : comparison_cnt_dag(0), comparison_cnt_sort(0) {}

int SelectWithPrediction::BinarySearch(const std::vector<int>& arr, int val, bool flag){
    int left = 0, right = arr.size();
    while (left < right) {
        int mid = (left + right) / 2;
        comparison_cnt_dag++;
        if (arr[mid] < val) {
            left = mid + 1;
        } else {
            right = mid;
        }
    }
    return left;
}

// Randomized quickselect over a partially ordered candidate set.
//
// Returns the k-th smallest (1-based) element of A. The candidates are
// partitioned around a randomly drawn pivot using two sources of order
// information:
//   * dag.Reachable(pivot, left_set, right_set, active), and
//   * a binary search for the pivot in each chain of `chains` (each chain is
//     assumed sorted ascending, since BinarySearch relies on it; confirm with callers).
// Pivots are redrawn until the left side holds between 1/4 and 3/4 of |A|.
// NOTE(review): if no such pivot exists (e.g. the DAG/chains do not fully
// order A) the redraw loop never terminates, as in the original.
//
// Parameters:
//   A       candidate elements (node ids, compared by value).
//   k       1-based rank to select; expected in [1, A.size()].
//   active  per-node flags; the pivot and the discarded side are cleared.
//   chains  chains of candidates (see above).
//   dag     provides Reachable().
// Throws std::out_of_range if the candidate set becomes empty, which happens
// when k is outside [1, A.size()]. This was previously undefined behavior in
// the pivot distribution.
int SelectWithPrediction::RandomizedSelect(std::vector<int> &A, int k, std::vector<bool>& active, 
                                           std::vector<std::vector<int>>& chains, const DirectedAcyclicGraph& dag) {
    if (A.empty()) {
        throw std::out_of_range("RandomizedSelect: empty candidate set (is k out of range?)");
    }
    if (A.size() == 1) {
        return A[0];
    }

    // One engine per thread, seeded once. Constructing std::random_device and
    // reseeding on every recursive call is costly and adds no randomness.
    static thread_local std::mt19937 gen{std::random_device{}()};
    std::uniform_int_distribution<std::size_t> dist(0, A.size() - 1);

    const std::size_t n = A.size();
    int pivot = 0;
    std::unordered_set<int> left_set, right_set;
    std::vector<std::vector<int>> left_chains, right_chains;
    do {
        left_chains.clear();
        right_chains.clear();
        left_set.clear();
        right_set.clear();
        pivot = A[dist(gen)];
        dag.Reachable(pivot, left_set, right_set, active);
        for (const auto &chain : chains) {
            // Split the chain at the pivot's lower-bound position. If the pivot
            // itself is in the chain, it belongs to neither side.
            const int position = BinarySearch(chain, pivot, 1);
            const bool pivot_in_chain =
                static_cast<std::size_t>(position) < chain.size() && chain[position] == pivot;
            const auto split = chain.begin() + position;
            std::vector<int> left_chain(chain.begin(), split);
            std::vector<int> right_chain(pivot_in_chain ? split + 1 : split, chain.end());

            left_set.insert(left_chain.begin(), left_chain.end());
            right_set.insert(right_chain.begin(), right_chain.end());

            if (!left_chain.empty()) {
                left_chains.push_back(std::move(left_chain));
            }
            if (!right_chain.empty()) {
                right_chains.push_back(std::move(right_chain));
            }
        }
        // Exact integer form of: |left| < 0.25*n || |left| > 0.75*n.
    } while (4 * left_set.size() < n || 4 * left_set.size() > 3 * n);

    active[pivot] = false;
    const std::size_t left_count = left_set.size();
    if (left_count >= static_cast<std::size_t>(k)) {
        // Rank k lies among the smaller elements: discard the larger side.
        for (const int v : right_set) {
            active[v] = false;
        }
        std::vector<int> left(left_set.begin(), left_set.end());
        return RandomizedSelect(left, k, active, left_chains, dag);
    }
    if (left_count + 1 < static_cast<std::size_t>(k)) {
        // Rank k lies among the larger elements: discard the smaller side.
        for (const int v : left_set) {
            active[v] = false;
        }
        std::vector<int> right(right_set.begin(), right_set.end());
        return RandomizedSelect(right, k - static_cast<int>(left_count + 1), active, right_chains, dag);
    }
    // Exactly k-1 elements are smaller than the pivot.
    return pivot;
}



// Deterministic selection (median-of-medians style) over a partially ordered
// candidate set.
//
// Returns the k-th smallest (1-based) element of A.
// - Inputs of at most five elements are sorted and indexed directly.
// - Otherwise dag.Median(active, local_active) supplies a list of medians
//   together with a node mask (local_active) for the median subproblem. The
//   pivot is the median of those medians, found recursively on the chains
//   restricted to that mask.
// - The candidates are then partitioned around the pivot using
//   dag.Reachable(...) plus a binary search in each chain (chains assumed
//   sorted ascending; confirm with callers). The search recurses into the side
//   that contains rank k.
//
// Parameters:
//   A       candidate elements. Outside the base case A itself is not read;
//           the partition works from `active`, `chains` and the DAG.
//   k       1-based rank. If it is outside [1, A.size()] when the base case is
//           reached, std::out_of_range is thrown (previously undefined behavior).
//   active  per-node flags; the pivot and the discarded side are cleared.
//   chains  chains of candidates.
//   dag     provides Median(), Reachable(), GetNumNodes(), GetComparisonCountSort().
//   flag    not consulted here; forwarded unchanged to the partition recursion
//           (the median recursion always receives 0).
int SelectWithPrediction::DeterministicSelect(
    std::vector<int>& A, int k, std::vector<bool>& active, 
    const std::vector<std::vector<int>> &chains, 
    const DirectedAcyclicGraph& dag, bool flag){

    // NOTE(review): recursive_call is incremented on every entry but decremented
    // only on the base-case path. After a top-level call it has therefore grown
    // by the number of non-base invocations. If it is meant to track the current
    // depth, every return path should decrement. Confirm intent before changing.
    recursive_call++;
    if (A.size() <= 5) {
        std::sort(A.begin(), A.end());
        recursive_call--;
        if (k < 1 || static_cast<std::size_t>(k) > A.size()) {
            throw std::out_of_range("DeterministicSelect: rank k is outside the candidate set");
        }
        return A[k - 1];
    }

    /* pivot: median of the medians reported by the DAG */
    std::vector<bool> local_active(dag.GetNumNodes(), false);
    std::vector<int> medians = dag.Median(active, local_active);
    // NOTE(review): if GetComparisonCountSort() returns a running total rather
    // than the count for this Median call, this line over-counts. Verify in
    // DirectedAcyclicGraph.
    comparison_cnt_sort += dag.GetComparisonCountSort();

    // Restrict every chain to the nodes enabled in local_active.
    std::vector<std::vector<int>> local_chains;
    for (const auto &chain : chains) {
        std::vector<int> local_chain;
        for (const int v : chain) {
            if (local_active[v]) {
                local_chain.push_back(v);
            }
        }
        if (!local_chain.empty()) {
            local_chains.push_back(std::move(local_chain));
        }
    }
    // This recursion may clear entries of local_active; it is not used afterwards.
    const int pivot = DeterministicSelect(medians, static_cast<int>((medians.size() + 1) / 2),
                                          local_active, local_chains, dag, 0);

    /* partition */
    std::unordered_set<int> left_set, right_set;
    dag.Reachable(pivot, left_set, right_set, active);

    std::vector<std::vector<int>> left_chains, right_chains;
    for (const auto &chain : chains) {
        // Split the chain at the pivot's lower-bound position. If the pivot
        // itself is in the chain, it belongs to neither side.
        const int position = BinarySearch(chain, pivot, 1);
        const bool pivot_in_chain =
            static_cast<std::size_t>(position) < chain.size() && chain[position] == pivot;
        const auto split = chain.begin() + position;
        std::vector<int> left_chain(chain.begin(), split);
        std::vector<int> right_chain(pivot_in_chain ? split + 1 : split, chain.end());

        left_set.insert(left_chain.begin(), left_chain.end());
        right_set.insert(right_chain.begin(), right_chain.end());

        if (!left_chain.empty()) {
            left_chains.push_back(std::move(left_chain));
        }
        if (!right_chain.empty()) {
            right_chains.push_back(std::move(right_chain));
        }
    }

    active[pivot] = false;
    const std::size_t left_count = left_set.size();
    if (left_count >= static_cast<std::size_t>(k)) {
        // Rank k lies among the smaller elements: discard the larger side.
        for (const int v : right_set) {
            active[v] = false;
        }
        std::vector<int> left(left_set.begin(), left_set.end());
        return DeterministicSelect(left, k, active, left_chains, dag, flag);
    }
    if (left_count + 1 < static_cast<std::size_t>(k)) {
        // Rank k lies among the larger elements: discard the smaller side.
        for (const int v : left_set) {
            active[v] = false;
        }
        std::vector<int> right(right_set.begin(), right_set.end());
        return DeterministicSelect(right, k - static_cast<int>(left_count + 1), active,
                                   right_chains, dag, flag);
    }
    // Exactly k-1 elements are smaller than the pivot.
    return pivot;
}


int SelectWithPrediction::GetComparisonCount() const {
    return comparison_cnt_dag;
}
int SelectWithPrediction::GetComparisonCountSort() const {
    return comparison_cnt_sort;
}